from datasets import load_dataset
from nltk.util import ngrams
import re

# Compare how many "real" (Scrabble-dictionary) words appear in the normal
# versus the perturbed model outputs for the AlpacaEval prompts, and print the
# mean per-output valid-word fraction for each set plus their ratio.

# Strips everything except ASCII letters from a token before dictionary lookup.
regex = re.compile('[^a-zA-Z]')

# Size of the word n-grams used to detect outputs that merely echo the prompt.
NGRAM_SIZE = 5
# An output is discarded when more than this fraction of its n-grams also
# occur in the prompt.
PROMPT_OVERLAP_THRESHOLD = 0.5


def _extract_response(raw_line):
    """Return the whitespace-normalised assistant response stored in *raw_line*.

    The first element of the evaluated line is split on ``<|assistant|>``; the
    last piece is kept, runs of whitespace are collapsed to single spaces and
    every ``<|endoftext|>`` marker is removed.
    """
    # SECURITY: eval() executes arbitrary code found in the outputs file. It
    # is kept so existing files load exactly as before; if every line is a
    # plain Python literal (presumably a list/tuple repr -- TODO confirm),
    # ast.literal_eval would be the safe replacement.
    generated = eval(raw_line)[0]
    response = generated.split("<|assistant|>")[-1]
    return ' '.join(response.split()).replace("<|endoftext|>", "")


def _load_responses(path):
    """Read *path* and return one cleaned response string per line."""
    with open(path, "r") as handle:
        return [_extract_response(raw_line) for raw_line in handle]


def _is_valid_word(token, valid_words):
    """Return True if *token*, reduced to its letters, is in *valid_words*.

    Tokens with an upper-case letter after the first position (e.g. ``wOrD``)
    are rejected. The remaining letters are upper-cased before the lookup, so
    *valid_words* is expected to hold upper-case entries.
    """
    letters = regex.sub('', token)
    tail = letters[1:]
    if tail != tail.lower():  # to remove words like wOrD
        return False
    return letters.upper() in valid_words


def _sum_valid_word_fractions(lines, prompts, valid_words, echo=False):
    """Return the sum over *lines* of each output's valid-word fraction.

    Args:
        lines: cleaned model outputs; ``lines[i]`` answers ``prompts[i]``.
        prompts: instruction strings, indexed in step with *lines*.
        valid_words: set of accepted (upper-case) dictionary words.
        echo: when True, print every output before it is scored.

    Returns:
        The sum of ``valid_tokens / total_tokens`` over the outputs that are
        scored. Outputs that are empty, contain no tokens, or share more than
        ``PROMPT_OVERLAP_THRESHOLD`` of their n-grams with the prompt are
        skipped and contribute nothing.

    Raises:
        IndexError: if a scored output has no prompt at the same index.
    """
    total = 0
    for index, line in enumerate(lines):
        if echo:
            print(line)
        if not line:
            continue
        tokens = line.split()
        if not tokens:
            # A whitespace-only line (e.g. what is left after stripping
            # end-of-text markers) has no n-grams and would otherwise cause a
            # division by zero below; treat it like an empty line.
            continue
        valid_count = sum(1 for token in tokens if _is_valid_word(token, valid_words))
        prompt_ngrams = set(ngrams(prompts[index].split(), NGRAM_SIZE, pad_right=True))
        output_ngrams = set(ngrams(tokens, NGRAM_SIZE, pad_right=True))
        # Skip outputs that mostly repeat the prompt.
        overlap = len(prompt_ngrams & output_ngrams) / len(output_ngrams)
        if overlap > PROMPT_OVERLAP_THRESHOLD:
            continue
        total += valid_count / len(tokens)
    return total


def _safe_div(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is 0."""
    if denominator == 0:
        return float("nan")
    return numerator / denominator


with open("scrabble.txt", "r") as f:
    words = [line.rstrip() for line in f]
valid_words = set(words)

normallines = _load_responses("../normal_alpaca_outputs.txt")
perturbedlines = _load_responses("../perturbed_alpaca_outputs.txt")

eval_set = load_dataset("tatsu-lab/alpaca_eval", "alpaca_eval", trust_remote_code=True)["eval"]
prompts = [line["instruction"] for line in eval_set]

# The perturbed outputs are echoed to stdout while being scored, exactly as
# the script has always done; the normal outputs are not.
perturbedsum = _sum_valid_word_fractions(perturbedlines, prompts, valid_words, echo=True)
normalsum = _sum_valid_word_fractions(normallines, prompts, valid_words)

# NOTE(review): the means divide by the total number of lines, so skipped
# outputs (empty or prompt-echoing) count as a score of 0 rather than being
# excluded -- confirm this is the intended normalisation.
normal_mean = _safe_div(normalsum, len(normallines))
perturbed_mean = _safe_div(perturbedsum, len(perturbedlines))

print("Normal:", normal_mean)
print("Perturbed:", perturbed_mean)
print(_safe_div(perturbed_mean, normal_mean))